package nsalt.mem

import chisel3._
import chisel3.util._

import nsalt._
import nsalt.bus._
import nsalt.func._
import nsalt.util._


/** Execution stage of a TLB, elaborated either as an instruction TLB (`TLB_NAME == "itlb"`)
  * or a data TLB (`TLB_NAME == "dtlb"`).
  *
  * For the request presented on `io.in` it:
  *  - compares its virtual page number against the `WAYS` raw entries in `io.group`
  *    (entry valid flag, ASID taken from `satp`, and the entry's page mask via `MaskEQ`);
  *  - on a hit, checks the entry's permission flags, forwards the request on `io.out` with
  *    the translated address, and — when the entry's A flag (or, for writes, its D flag) is
  *    clear — writes the updated PTE back through `io.mem` and updates the TLB entry through
  *    `io.groupWrite`;
  *  - on a miss, walks the page table through `io.mem` using the state machine below and
  *    refills a way chosen by `waymask` (random victim via LFSR when `WAYS > 1`).
  *
  * itlb/dtlb differences are selected at elaboration time with plain Scala `if`s on
  * `TLB_NAME`. Several outputs (`io.out.valid`, `io.isFinish`) get a default and are then
  * reconnected for "dtlb" near the end of the module; Chisel's last-connect semantics make
  * the later connection win.
  */
class TransLookasideExec(implicit val conf: TransLookasideConf) extends Module with TransLookasideConst with CtrlStatusRegConst{

  val io = IO(new Bundle {
    
    // incoming request carrying a virtual address
    val in = Flipped(Decoupled(new BusReqPort(userBits = USER_BITS, addrBits = VIRT_MEM_ADDR_LEN)))
    // outgoing request whose address has been translated (see io.out.bits.addr below)
    val out = Decoupled(new BusReqPort(userBits = USER_BITS))

    // one raw TLB entry per way; presumably the set indexed by the request address and
    // selected (and possibly registered) outside this module — TODO confirm
    val group = Input(Vec(WAYS, UInt(tlbLen.W)))
    // write port used for both miss refills and A/D-flag updates of hit entries
    val groupWrite = new TransLookasideTableWriteBundle(IndexBits = IndexBits, Ways = WAYS, tlbLen = tlbLen)
    val groupReady = Input(Bool())

    // memory port used for page-table reads and PTE write-backs
    val mem = new BusUncached(userBits = USER_BITS)
    val flush = Input(Bool())
    val satp = Input(UInt(XLEN.W))
    // privilege mode / SUM / MXR are read from this port; load/store page-fault flags and
    // the faulting address are driven onto it
    val pageFault = new MemManagePort
    // instruction page fault; only ever raised in the "itlb" configuration
    val ipf = Output(Bool())
    val isFinish = Output(Bool())
  })

  val group = io.group//RegEnable(groupTLB.io.tlbgroup, io.in.ready)
  
  // lazy renaming
  val req = io.in.bits
  val virtPageNum = req.addr
    .asTypeOf(vaBundle2).virtPageNum
    .asTypeOf(virtPageNumBundle)
  val pageFault = io.pageFault
  val satp = io.satp.asTypeOf(satpBundle)
  // elaboration-time constant: instruction fetches get the stricter S-mode/U-page check
  val iFetch = if(TLB_NAME == "itlb") true.B else false.B

  // pf init (the fault flags are reconnected below from the local wires)
  pageFault.loadPageFault := false.B
  pageFault.storePageFault := false.B
  pageFault.addr := req.addr

  // check hit or miss
  // val hitVec = VecInit(group.map(m => m.asTypeOf(tlbBundle).flag.asTypeOf(flagBundle).v && (m.asTypeOf(tlbBundle).asid === satp.asid) && MaskEQ(m.asTypeOf(tlbBundle).mask, m.asTypeOf(tlbBundle).virtPageNum, virtPageNum.asUInt))).asUInt
  // one bit per way: entry valid, same ASID, and VPN equal under the entry's page mask
  val hitVec = VecInit(group.map(entry => {
    val isFlaggedV = entry
      .asTypeOf(tlbBundle).flag
      .asTypeOf(flagBundle).v
    
    val isASIDset  = entry.asTypeOf(tlbBundle).asid === satp.asid

    val maskedEqual = MaskEQ(entry.asTypeOf(tlbBundle).mask, entry.asTypeOf(tlbBundle).virtPageNum, virtPageNum.asUInt)
    
    isFlaggedV && isASIDset && maskedEqual

  })).asUInt

  
  val hit  = io.in.valid && hitVec.orR
  val miss = io.in.valid && !hitVec.orR

  // on a miss, the refill goes into a pseudo-randomly chosen way
  val victimWaymask = if (WAYS > 1) (1.U << LFSR64()(log2Up(WAYS)-1,0)) else "b1".U
  val waymask = Mux(hit, hitVec, victimWaymask)

  val loadPageFault = WireInit(false.B)
  val storePageFault = WireInit(false.B)

  // hit: decode the selected entry
  val hitMeta = Mux1H(waymask, group).asTypeOf(tlbBundle2).meta.asTypeOf(metaBundle)
  val hitData = Mux1H(waymask, group).asTypeOf(tlbBundle2).data.asTypeOf(dataBundle)
  val hitFlag = hitMeta.flag.asTypeOf(flagBundle)
  val hitMask = hitMeta.mask

  // hit write back pte.flag: needed when A is clear, or D is clear on a write,
  // and the access is not faulting anyway
  val hitInstrPageFault = WireInit(false.B)
  val hitWriteBack = hit && 
    (!hitFlag.a || !hitFlag.d && req.isWrite()) && 
    !hitInstrPageFault && !(loadPageFault || storePageFault || io.pageFault.isPageFault())

  // sets flag bit 6 always and bit 7 on writes (A and D in the RISC-V PTE layout —
  // presumably the same layout as flagBundle)
  val hitRefillFlag = Cat(req.isWrite().asUInt, 1.U(1.W), 0.U(6.W)) | hitFlag.asUInt
  // full 64-bit PTE image written back in s_write_pte: 10 reserved bits, PPN, 2 RSW bits, flags
  val hitWriteBackStore = RegEnable(Cat(0.U(10.W), hitData.physPageNum, 0.U(2.W), hitRefillFlag), hitWriteBack)

  // hit permission check
  val hitCheck = hit &&
    !(pageFault.priviledgeMode === ModeU && !hitFlag.u) && 
    !(pageFault.priviledgeMode === ModeS && hitFlag.u && (!pageFault.statusSUM || iFetch))
  val hitExec = hitCheck && hitFlag.x
  val hitLoad = hitCheck && (hitFlag.r || pageFault.statusMXR && hitFlag.x)
  val hitStore = hitCheck && hitFlag.w

  io.pageFault.loadPageFault := loadPageFault //RegNext(loadPageFault, init =false.B)
  io.pageFault.storePageFault := storePageFault //RegNext(storePageFault, init = false.B)

  if (TLB_NAME == "itlb") { 
    hitInstrPageFault := !hitExec && hit
  }
  if (TLB_NAME == "dtlb") { 
    loadPageFault := !hitLoad && req.isRead() && hit
    storePageFault := (!hitStore && req.isWrite() && hit)
    // AMO pagefault type will be fixed in LSU
  }

  // miss
  // s_idle        : waiting; starts a walk on miss or a PTE write-back on hitWriteBack
  // s_memReadReq  : issue a page-table read at addrPhys
  // s_memReadResp : consume the PTE, descend a level, or finish (refill / fault)
  // s_write_pte   : write an updated PTE to memory
  // s_wait_resp   : translation result available on io.out
  // s_miss_slpf   : one-cycle state after a faulting walk in the "dtlb" configuration
  val s_idle :: s_memReadReq :: s_memReadResp :: s_write_pte :: s_wait_resp :: s_miss_slpf :: Nil = Enum(6)
  
  val state = RegInit(s_idle)
  // current walk level; counts down from Level (the walk below checks 3, 2, then 1)
  val level = RegInit(Level.U(log2Up(Level).W))
  
  val memRespStore = Reg(UInt(XLEN.W))

  // page mask of the leaf found by the walk (0 at level 3, h3fe00 at level 2, h3ffff at level 1)
  val missMask = WireInit("h3ffff".U(maskLen.W))
  val missMaskStore = Reg(UInt(maskLen.W))
  val missMetaRefill = WireInit(false.B)
  val missRefillFlag = WireInit(0.U(8.W))

  val memRdata = io.mem.res.bits.dataR.asTypeOf(pteBundle)
  // physical address of the PTE being read
  val addrPhys = Reg(UInt(PHYS_ADDR_LEN.W))
  // latched once the result has been delivered (out.fire for itlb, out.valid for dtlb)
  val afterFire = RegEnable(true.B, init = false.B, if(TLB_NAME == "itlb") io.out.fire else io.out.valid)

  //handle flush: a flush arriving mid-operation is remembered in needFlush
  val needFlush = RegInit(false.B)
  val ioFlush = io.flush
  val isFlush = needFlush || ioFlush
  when (ioFlush && (state =/= s_idle)) { needFlush := true.B}
  if(TLB_NAME == "itlb"){
    when (io.out.fire && needFlush) { needFlush := false.B}
  }
  if(TLB_NAME == "dtlb"){
    when (io.out.valid && needFlush) { needFlush := false.B}
  }

  val missIPF = RegInit(false.B)

  // state machine to handle miss(ptw) and pte-writing-back
  switch (state) {
    is (s_idle) {
      when (!ioFlush && hitWriteBack) {
        state := s_write_pte
        needFlush := false.B
        afterFire := false.B
      }.elsewhen (miss && !ioFlush) {
        state := s_memReadReq
        addrPhys := physAddrApply(satp.physPageNum, virtPageNum.virtPageNum2) //
        level := Level.U
        needFlush := false.B
        afterFire := false.B
      }
    }

    is (s_memReadReq) { 
      when (isFlush) {
        state := s_idle
        needFlush := false.B
      }.elsewhen (io.mem.req.fire) { state := s_memReadResp}
    }

    is (s_memReadResp) { 
      val missflag = memRdata.flag.asTypeOf(flagBundle)
      when (io.mem.res.fire) {
        when (isFlush) {
          state := s_idle
          needFlush := false.B
        }.elsewhen (!(missflag.r || missflag.x) && (level===3.U || level===2.U)) {
          // pointer PTE (R=X=0) above the last level: fault if invalid or W without R,
          // otherwise descend to the next level
          when(!missflag.v || (!missflag.r && missflag.w)) { //TODO: fix needflush
            if(TLB_NAME == "itlb") { state := s_wait_resp } else { state := s_miss_slpf }
            if(TLB_NAME == "itlb") { missIPF := true.B }
            if(TLB_NAME == "dtlb") { 
              loadPageFault := req.isRead()
              storePageFault := req.isWrite() 
            }  
            // Debug("tlbException!!! ")
            // Debug(false, p" req:${req}  Memreq:${io.mem.req}  MemResp:${io.mem.res}")
            // Debug(false, " level:%d",level)
            // Debug(false, "\n")
          }.otherwise {
            state := s_memReadReq
            addrPhys := physAddrApply(memRdata.physPageNum, Mux(level === 3.U, virtPageNum.virtPageNum1, virtPageNum.virtPageNum0))
          }
        }.elsewhen (level =/= 0.U) { //TODO: fix needFlush
          // Leaf PTE (or a pointer PTE at the last level, which fails the checks below).
          // W=1 with R=0 is a reserved encoding that must fault at every level, exactly as
          // the pointer-PTE branch above already enforces.
          // NOTE(review): no misaligned-superpage check (non-zero low PPN bits for a
          // level-3 / level-2 leaf) is performed here — confirm whether that is intended.
          val reservedEncoding = !missflag.r && missflag.w
          val permCheck = missflag.v && !reservedEncoding && !(pageFault.priviledgeMode === ModeU && !missflag.u) && !(pageFault.priviledgeMode === ModeS && missflag.u && (!pageFault.statusSUM || iFetch))
          val permExec = permCheck && missflag.x
          val permLoad = permCheck && (missflag.r || pageFault.statusMXR && missflag.x)
          val permStore = permCheck && missflag.w
          // val updateAD = if (Settings.get("FPGAPlatform")) !missflag.a || (!missflag.d && req.isWrite()) else false.B
          // hard-wired off: the walk never transitions to s_write_pte
          val updateAD = false.B
          val updateData = Cat( 0.U(56.W), req.isWrite(), 1.U(1.W), 0.U(6.W) )
          missRefillFlag := Cat(req.isWrite(), 1.U(1.W), 0.U(6.W)) | missflag.asUInt
          memRespStore := io.mem.res.bits.dataR | updateData 
          if(TLB_NAME == "itlb") {
            when (!permExec) { missIPF := true.B ; state := s_wait_resp}
            .otherwise { 
              state := Mux(updateAD, s_write_pte, s_wait_resp)
              missMetaRefill := true.B
            }
          }
          if(TLB_NAME == "dtlb") {
            when((!permLoad && req.isRead()) || (!permStore && req.isWrite())) { 
              state := s_miss_slpf
              loadPageFault := req.isRead()
              storePageFault := req.isWrite()
            }.otherwise {
              state := Mux(updateAD, s_write_pte, s_wait_resp)
              missMetaRefill := true.B
            }
          }
          missMask := Mux(level===3.U, 0.U(maskLen.W), Mux(level===2.U, "h3fe00".U(maskLen.W), "h3ffff".U(maskLen.W)))
          missMaskStore := missMask
        }
        level := level - 1.U
      }
    }

    is (s_write_pte) {
      when (isFlush) {
        state := s_idle
        needFlush := false.B
      }.elsewhen (io.mem.req.fire) { state := s_wait_resp }
    }

    is (s_wait_resp) { 
      // itlb holds the result until it is consumed (or flushed); dtlb leaves after one cycle
      if(TLB_NAME == "itlb"){
        when (io.out.fire || ioFlush || afterFire){
          state := s_idle
          missIPF := false.B
          afterFire := false.B
        }
      }
      if(TLB_NAME == "dtlb"){
        state := s_idle
        missIPF := false.B
        afterFire := false.B
      }
    }

    is (s_miss_slpf) {
      state := s_idle
    }
  }

  // mem: page-table reads at addrPhys, or the PTE write-back for a hit entry
  val command = Mux(state === s_write_pte, BusCommand.WRITE, BusCommand.READ)
  io.mem.req.bits.apply(
    addr = Mux(hitWriteBack, hitData.entryAddr, addrPhys), 
    command = command, 
    size = (if (XLEN == 64) "b11".U else "b10".U), 
    dataW =  Mux( hitWriteBack, hitWriteBackStore, memRespStore), 
    maskW = 0xff.U
  )
  io.mem.req.valid := ((state === s_memReadReq || state === s_write_pte) && !isFlush)
  io.mem.res.ready := true.B

  // tlb refill: all fields are registered, so the write lands one cycle after the decision
  io.groupWrite.apply(
    valid = RegNext((missMetaRefill && !isFlush) || (hitWriteBack && state === s_idle && !isFlush), init = false.B), 
    index = RegNext(getIndex(req.addr)), 
    waymask = RegNext(waymask), 
    virtPageNum = RegNext(virtPageNum.asUInt), 
    
    // entry data
    asid = RegNext(Mux(hitWriteBack, hitMeta.asid, satp.asid)), 
    mask = RegNext(Mux(hitWriteBack, hitMask, missMask)), 
    flag = RegNext(Mux(hitWriteBack, hitRefillFlag, missRefillFlag)),
    physPageNum = RegNext(Mux(hitWriteBack, hitData.physPageNum, memRdata.physPageNum)), 
    entryAddr = RegNext((Mux(hitWriteBack, hitData.entryAddr, addrPhys))))

  // io
  io.out.bits := req
  io.out.bits.addr := Mux(hit, maskPaddr(hitData.physPageNum, req.addr(PHYS_ADDR_LEN-1, 0), hitMask), maskPaddr(memRespStore.asTypeOf(pteBundle).physPageNum, req.addr(PHYS_ADDR_LEN-1, 0), missMaskStore))
  // default (itlb) valid; reconnected below for dtlb
  io.out.valid := io.in.valid && Mux(hit && !hitWriteBack, !(io.pageFault.isPageFault() || loadPageFault || storePageFault), state === s_wait_resp)// && !afterFire
  
  io.in.ready := io.out.ready && (state === s_idle) && !miss && !hitWriteBack && io.groupReady && (!io.pageFault.isPageFault() && !loadPageFault && !storePageFault)//maybe be optimized

  io.ipf := Mux(hit, hitInstrPageFault, missIPF)
  io.isFinish := io.out.fire || io.pageFault.isPageFault()

  // dtlb overrides (last connect wins): a faulting dtlb access is still presented on io.out
  if(TLB_NAME == "dtlb") {
    io.isFinish := io.out.valid || io.pageFault.isPageFault()
    io.out.valid := io.in.valid && (Mux(hit && !hitWriteBack, true.B, state === s_wait_resp) || loadPageFault || storePageFault)// && !afterFire
  }
}


